function [X,Y,D,n,ps,Yscale,x1_wave,prevuse_wave,Xscale] = generate_inputs_cubic_rule_reweight(sample,sample_full,wave,tcost,covariate,SMC)
% This function generates the inputs for the EWM cubic rule search
% algorithm, specifically for the cross waves analysis.
%
% Input
% (1) sample: cross-sectional, wave-specific sample (table)
% (2) sample_full: cross-sectional, pooled sample; only used to derive the
%     bin edges and the maxima that normalize X
% (3) wave: =0 uses every row of sample (pooled); any other value keeps
%     only the rows with sample.opower_paper_wave==wave (e.g. 3, 6, or 7)
% (4) tcost: program cost per household. >0 represents cost savings; =0
%     represents kWh reduction
% (5) covariate: selects the covariate for the analysis: 'income', 'size',
%     'vintage', 'min' (minimum of baseline consumption), 'max' (maximum of
%     baseline consumption), or 'std' (standard deviation of consumption)
% (6) SMC: =1 use social marginal cost; =0 use retail electricity price.
%     Only matters when tcost is nonzero.
%
% Output
% (1) X: [1, x, x^2, x^3, baseline consumption], where x is the selected
%     covariate (discretized, except income) and every non-constant column
%     is divided by its full-sample maximum
% (2) Y: demeaned and scaled cost of electricity consumption, net of
%     program cost
% (3) D: actual Opower treatment status in RCT
% (4) n: number of households
% (5) ps: propensity scores (sample.ps when wave==0, otherwise the share of
%     treated households in the wave)
% (6) Yscale: scale factor used to normalize Y, max(abs(Y)) after demeaning
% (7) x1_wave: the selected covariate for the analysed rows, before
%     discretization and scaling
% (8) prevuse_wave: rounded baseline consumption for the analysed rows,
%     before discretization and scaling
% (9) Xscale: n-by-4 matrix of the scale factors used to normalize X

%% Validate the covariate selector up front
% Without this check an unknown covariate only fails much later with an
% unrelated "undefined variable" error.
known_covariates = {'income','size','vintage','min','max','std'};
if ~any(strcmp(covariate,known_covariates))
    error('generate_inputs_cubic_rule_reweight:unknownCovariate', ...
        'Unknown covariate "%s"; expected one of: %s.', ...
        string(covariate),strjoin(known_covariates,', '));
end

%% Select the rows to analyse
if wave==0 % pooled
    sample_wave = sample;
else % wave-specific
    sample_wave = sample(sample.opower_paper_wave==wave,:);
end

usage_wave = sample_wave.post_avg - sample_wave.pre_avg;
opower_wave = sample_wave.opower;
prevuse_wave = round(sample_wave.pre_avg);

%% Discretize baseline consumption
% We need the same values after discretization for the full sample and the
% wave-specific sample (training/testing set within wave) to generate
% density ratios, so the bin edges always come from the full sample.
n_bins = 30;
[prevuse_dcrt,prevuse_max] = discretize_on_full_edges(...
    round(sample_full.pre_avg),prevuse_wave,n_bins);

%% Prepare the selected covariate
% x1_wave: raw covariate, x_dcrt: value entering the cubic terms,
% x_max: full-sample maximum used for scaling.
switch covariate
    case 'income'
        % income needs no discretization, but its max comes from the full sample
        x1_wave = sample_wave.income;
        x_dcrt = x1_wave;
        x_max = max(sample_full.income);
    case 'size'
        x1_wave = sample_wave.uni_size;
        [x_dcrt,x_max] = discretize_on_full_edges(...
            sample_full.uni_size,x1_wave,n_bins);
    case 'vintage'
        x1_wave = sample_wave.vintage;
        [x_dcrt,x_max] = discretize_on_full_edges(...
            sample_full.vintage,x1_wave,n_bins);
    case 'min'
        x1_wave = sample_wave.min_baseline;
        [x_dcrt,x_max] = discretize_on_full_edges(...
            sample_full.min_baseline(:,1),x1_wave,n_bins);
    case 'max'
        x1_wave = sample_wave.max_baseline;
        [x_dcrt,x_max] = discretize_on_full_edges(...
            sample_full.max_baseline(:,1),x1_wave,n_bins);
    case 'std'
        x1_wave = sample_wave.std_baseline;
        [x_dcrt,x_max] = discretize_on_full_edges(...
            sample_full.std_baseline(:,1),x1_wave,n_bins);
end

%% Y in terms of cost savings or energy conservation
if (tcost)
    % Coefficient applied to the change in usage: 0.065 with social
    % marginal cost, 0.177 with the retail electricity price
    % (presumably $ per kWh -- TODO confirm units).
    if (SMC)
        price = 0.065;
    else
        price = 0.177;
    end
    Y_all = price.*usage_wave(:,1)+tcost.*opower_wave(:,1);
else
    % no program cost: the outcome is the raw change in usage
    Y_all = usage_wave(:,1);
end

if wave==0 % in the pooled sample
    demean_wave = 1;  %=1 demean within wave
else % wave-specific analysis, always demean by "full" wave sample
    demean_wave = 0;
end
[Y_all] = Y_demean(Y_all,sample,demean_wave);

%% Define Y, Yscale, D, n, and ps as final inputs
Yscale = max(abs(Y_all));
Y = Y_all./Yscale; % rescale demeaned outcomes to [-1,1]
D = opower_wave;
n = length(Y);
if wave==0
    ps = sample.ps;
else
    ps=sum(opower_wave)/n; % make sure propensity score is calculated based on wave-specific data
end

%% Define X, Xscale
covariates = [x_dcrt x_dcrt.^2 x_dcrt.^3 prevuse_dcrt];
covariates_max = [x_max x_max.^2 x_max.^3 prevuse_max];

Xscale = ones(n,1)*covariates_max; % note: max value refers to max in the full sample, not the wave-specific sample
X = [ones(n,1) covariates./Xscale];

end

function [values_dcrt,edge_max] = discretize_on_full_edges(values_full,values_wave,n_bins)
% Bin values_wave using n_bins bin edges derived from values_full and
% replace each value by the lower edge of its bin.
%
% values_dcrt: column vector with the lower bin edge for each row of
%              values_wave (first column)
% edge_max:    largest bin edge of the full sample, used for scaling

[~,edges_full] = discretize(values_full,n_bins); % discretize full sample to find edges
bin_idx = discretize(values_wave,edges_full); % use full sample edges to discretize wave sample
bin_idx = bin_idx(:,1);

% discretize returns NaN for missing values and for values outside the
% edges; indexing with NaN would fail with an opaque message.
if any(isnan(bin_idx))
    error('generate_inputs_cubic_rule_reweight:valueOutsideFullSample', ...
        'Some values are missing or fall outside the bin edges derived from the full sample.');
end

values_dcrt = reshape(edges_full(bin_idx),[],1); % lower edge of each bin
edge_max = max(edges_full);
end